import glob

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import pandas as pd
import seaborn as sns

import os


def auto_correlation(x):
    """Return the normalised autocorrelation of ``x`` for non-negative lags.

    The mean of ``x`` is subtracted first, the centred signal is correlated
    with itself over all overlaps, and the half starting at lag 0 is divided
    by the lag-0 value, so element 0 of the result is 1 and element ``k`` is
    the autocorrelation at lag ``k``.

    Args:
        x: One-dimensional sequence of numbers.

    Returns:
        Array with one entry per lag ``0 .. len(x) - 1``.
    """
    centred = x - np.mean(x)
    all_lags = np.correlate(centred, centred, mode='full')
    # The 'full' output has 2 * len(x) - 1 entries, so the middle one is lag 0.
    zero_lag = all_lags.size // 2
    non_negative_lags = all_lags[zero_lag:]
    return non_negative_lags / all_lags[zero_lag]


def set_axis_ticker(ax):
    """Give both axes of ``ax`` scientific-notation labels and sparse ticks.

    Each axis gets its own ``ScalarFormatter`` (math-text, power limits
    ``(0, 0)``) as major formatter and a ``MaxNLocator(nbins=5)`` as major
    locator.

    Args:
        ax: The matplotlib Axes whose x and y axes are configured in place.
    """
    for axis in (ax.xaxis, ax.yaxis):
        # A formatter instance is bound to the axis it is installed on, so
        # sharing one between x and y would make the x labels use the
        # y-axis view limits. Build a fresh one per axis.
        formatter = ticker.ScalarFormatter(useMathText=True)
        formatter.set_powerlimits((0, 0))  # always use scientific notation
        axis.set_major_formatter(formatter)
        axis.set_major_locator(ticker.MaxNLocator(nbins=5))


def _save_line_plot(y, title, out_path, x=None, xlabel=None):
    """Draw ``y`` as a 4x3 inch line plot and save it to ``out_path``.

    Args:
        y: Values to plot.
        title: Title shown above the axes.
        out_path: Destination path handed to ``Figure.savefig``.
        x: Optional x coordinates. When omitted, matplotlib chooses them.
        xlabel: Optional label for the x axis.
    """
    fig, ax = plt.subplots(figsize=(4, 3))
    try:
        if x is None:
            ax.plot(y)
        else:
            ax.plot(x, y)
        ax.set_title(title)
        if xlabel is not None:
            ax.set_xlabel(xlabel)

        set_axis_ticker(ax)

        fig.tight_layout()
        fig.canvas.draw()
        fig.savefig(out_path, bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)


def plot_autocorrelation():
    """Save autocorrelation and series plots for the Aotizhongxin data.

    Reads the raw and the preprocessed CSV, then for every preprocessed
    column except ``station`` writes three PDFs to
    ``figures/autocorrelation``:

    * ``<column>_processed_auto_correlation.pdf``: autocorrelation of the
      preprocessed series for lags 100 to 199, plotted against the lag.
    * ``<column>_processed_seasonality.pdf``: the preprocessed series.
    * ``<column>_seasonality.pdf``: the matching raw series.

    The raw column is looked up under the upper-cased preprocessed column
    name, and that upper-cased name is used as the plot title.
    """
    save_dir = "figures/autocorrelation"
    os.makedirs(save_dir, exist_ok=True)

    raw_path = "data/raw/PRSA_Data_Aotizhongxin_20130301-20170228.csv"
    processed_path = "data/preprocessed/Aotizhongxin.csv"

    # Window of lags shown in the autocorrelation plot.
    lag_start, lag_stop = 100, 200

    raw_data = pd.read_csv(raw_path,
                           index_col=0,
                           header=0,)

    processed_data = pd.read_csv(processed_path,
                                 index_col='Datetime',
                                 header=0,
                                 parse_dates=["Datetime"],)

    for processed_feature in processed_data.columns:
        if processed_feature == "station":
            continue
        raw_feature = processed_feature.upper()

        raw_series = raw_data[raw_feature]
        processed_series = processed_data[processed_feature]
        processed_series = processed_series.reset_index(drop=True)

        # Autocorrelation of the preprocessed series (presumably with the
        # seasonality removed - confirm against the preprocessing step).
        windowed = auto_correlation(processed_series)[lag_start:lag_stop]
        # Plot against the real lag values; without explicit x coordinates
        # the slice would be drawn at 0..99 under a "Time lag" label.
        lags = np.arange(lag_start, lag_start + len(windowed))
        _save_line_plot(
            windowed,
            raw_feature,
            f"{save_dir}/{processed_feature}_processed_auto_correlation.pdf",
            x=lags,
            xlabel="Time lag",
        )

        # Preprocessed series.
        _save_line_plot(
            processed_series,
            raw_feature,
            f"{save_dir}/{processed_feature}_processed_seasonality.pdf",
        )

        # Original (raw) series.
        _save_line_plot(
            raw_series,
            raw_feature,
            f"{save_dir}/{processed_feature}_seasonality.pdf",
        )


def plot_corr_matrix():
    """Write one correlation-matrix heatmap PDF per preprocessed CSV file.

    Every ``*.csv`` under ``data/preprocessed`` is loaded with ``Datetime``
    as a parsed index. The first value of the ``station`` column becomes the
    plot title and part of the file name. That column is then dropped and
    the pairwise correlations of the remaining columns are drawn as a
    seaborn heatmap and saved to ``figures/corr_matrix``.
    """
    out_dir = "figures/corr_matrix"
    os.makedirs(out_dir, exist_ok=True)

    data_dir = "data/preprocessed"

    for csv_path in glob.glob(f"{data_dir}/*.csv"):
        frame = pd.read_csv(
            csv_path,
            index_col="Datetime",
            header=0,
            parse_dates=["Datetime"],
        )

        fig, ax = plt.subplots(figsize=(4, 3))
        station_name = frame["station"].iloc[0]
        ax.set_title(station_name)

        # Correlate everything except the station label column.
        correlations = frame.drop(columns=["station"]).corr()
        sns.heatmap(correlations, ax=ax)

        fig.savefig(f"{out_dir}/{station_name}_corr_matrix.pdf",
                    bbox_inches="tight")
        plt.close(fig)


if __name__ == "__main__":
    # Produce both figure sets, autocorrelation plots first.
    for make_figures in (plot_autocorrelation, plot_corr_matrix):
        make_figures()